# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from dataclasses import dataclass

import torch
from tensordict import TensorDict, TensorDictBase, TensorDictParams
from tensordict.nn import dispatch, TensorDictModule
from tensordict.utils import NestedKey

from torchrl.data.tensor_specs import Bounded, Composite, TensorSpec
from torchrl.envs.utils import step_mdp
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
    _cache_values,
    _GAMMA_LMBDA_DEPREC_ERROR,
    _reduce,
    _vmap_func,
    default_value_kwargs,
    distance_loss,
    ValueEstimators,
)
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator


class TD3Loss(LossModule):
    """TD3 Loss module.

    Args:
        actor_network (TensorDictModule): the actor to be trained
        qvalue_network (TensorDictModule): a single Q-value network or a list of
            Q-value networks.
            If a single instance of ``qvalue_network`` is provided, it will be
            duplicated ``num_qvalue_nets`` times. If a list of modules is passed,
            their parameters will be stacked unless they share the same identity
            (in which case the original parameter will be expanded).

            .. warning:: When a list of parameters is passed, it will __not__ be compared
              against the policy parameters, and all parameters will be considered untied.

    Keyword Args:
        bounds (tuple of float, optional): the bounds of the action space.
            Mutually exclusive with ``action_spec``: exactly one of the two
            must be provided.
        action_spec (TensorSpec, optional): the action spec.
            Mutually exclusive with ``bounds``: exactly one of the two must
            be provided.
        num_qvalue_nets (int, optional): Number of Q-value networks to be
            trained. Default is ``2``.
        policy_noise (:obj:`float`, optional): Standard deviation for the target
            policy action noise. Default is ``0.2``.
        noise_clip (:obj:`float`, optional): Clipping range value for the sampled
            target policy action noise. Default is ``0.5``.
        priority_key (str, optional): [Deprecated, use ``.set_keys(priority=...)``
            instead] the key where the priority value is written for prioritized
            replay buffers. Default is ``"td_error"``.
        loss_function (str, optional): loss function to be used for the Q-value.
            Can be one of ``"smooth_l1"``, ``"l2"`` or ``"l1"``. Default is
            ``"smooth_l1"``.
        delay_actor (bool, optional): whether to separate the target actor
            networks from the actor networks used for
            data collection. Default is ``True``.
        delay_qvalue (bool, optional): whether to separate the target Q-value
            networks from the Q-value networks used for data collection.
            Default is ``True``.
        separate_losses (bool, optional): if ``True``, shared parameters between
            policy and critic will only be trained on the policy loss.
            Defaults to ``False``, i.e., gradients are propagated to shared
            parameters for both policy and critic losses.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
            ``"mean"``: the sum of the output will be divided by the number of
            elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
        deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
            Defaults to ``False``.
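        use_prioritized_weights (bool or str, optional): whether to weight the
            sample losses with the priority weights found in the input
            tensordict (under the ``priority_weight`` key), as written e.g. by a
            prioritized replay buffer. If ``"auto"`` (default), the weights are
            used whenever they are present in the input tensordict.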

    Examples:
        >>> import torch
        >>> from torch import nn
        >>> from torchrl.data import Bounded
        >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
        >>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator
        >>> from torchrl.modules.tensordict_module.common import SafeModule
        >>> from torchrl.objectives.td3 import TD3Loss
        >>> from tensordict import TensorDict
        >>> n_act, n_obs = 4, 3
        >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
        >>> module = nn.Linear(n_obs, n_act)
        >>> actor = Actor(
        ...     module=module,
        ...     spec=spec)
        >>> class ValueClass(nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.linear = nn.Linear(n_obs + n_act, 1)
        ...     def forward(self, obs, act):
        ...         return self.linear(torch.cat([obs, act], -1))
        >>> module = ValueClass()
        >>> qvalue = ValueOperator(
        ...     module=module,
        ...     in_keys=['observation', 'action'])
        >>> loss = TD3Loss(actor, qvalue, action_spec=actor.spec)
        >>> batch = [2, ]
        >>> action = spec.rand(batch)
        >>> data = TensorDict({
        ...      "observation": torch.randn(*batch, n_obs),
        ...      "action": action,
        ...      ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
        ...      ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
        ...      ("next", "reward"): torch.randn(*batch, 1),
        ...      ("next", "observation"): torch.randn(*batch, n_obs),
        ...  }, batch)
        >>> loss(data)
        TensorDict(
            fields={
                loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                next_state_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                state_action_value_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)

    This class is compatible with non-tensordict based modules too and can be
    used without resorting to any tensordict-related primitive. In this case,
    the expected keyword arguments are:
    ``["action", "next_reward", "next_done", "next_terminated"]`` + the in_keys of the actor and qvalue network.
    The return value is a tuple of tensors in the following order:
    ``["loss_actor", "loss_qvalue", "pred_value", "state_action_value_actor", "next_state_value", "target_value"]``.

    Examples:
        >>> import torch
        >>> from torch import nn
        >>> from torchrl.data import Bounded
        >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
        >>> from torchrl.objectives.td3 import TD3Loss
        >>> n_act, n_obs = 4, 3
        >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
        >>> module = nn.Linear(n_obs, n_act)
        >>> actor = Actor(
        ...     module=module,
        ...     spec=spec)
        >>> class ValueClass(nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.linear = nn.Linear(n_obs + n_act, 1)
        ...     def forward(self, obs, act):
        ...         return self.linear(torch.cat([obs, act], -1))
        >>> module = ValueClass()
        >>> qvalue = ValueOperator(
        ...     module=module,
        ...     in_keys=['observation', 'action'])
        >>> loss = TD3Loss(actor, qvalue, action_spec=actor.spec)
        >>> _ = loss.select_out_keys("loss_actor", "loss_qvalue")
        >>> batch = [2, ]
        >>> action = spec.rand(batch)
        >>> loss_actor, loss_qvalue = loss(
        ...         observation=torch.randn(*batch, n_obs),
        ...         action=action,
        ...         next_done=torch.zeros(*batch, 1, dtype=torch.bool),
        ...         next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
        ...         next_reward=torch.randn(*batch, 1),
        ...         next_observation=torch.randn(*batch, n_obs))
        >>> loss_actor.backward()

    """

    @dataclass
    class _AcceptedKeys:
        """Maintains default values for all configurable tensordict keys.

        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
        default values.

        Attributes:
            action (NestedKey): The input tensordict key where the action is expected.
                Defaults to ``"action"``.
            state_action_value (NestedKey): The input tensordict key where the state action value is expected.
                Will be used for the underlying value estimator. Defaults to ``"state_action_value"``.
            priority (NestedKey): The tensordict key where the priority (max TD error)
                is written for prioritized replay. Defaults to ``"td_error"``.
            reward (NestedKey): The input tensordict key where the reward is expected.
                Will be used for the underlying value estimator. Defaults to ``"reward"``.
            done (NestedKey): The key in the input TensorDict that indicates
                whether a trajectory is done. Will be used for the underlying value estimator.
                Defaults to ``"done"``.
            terminated (NestedKey): The key in the input TensorDict that indicates
                whether a trajectory is terminated. Will be used for the underlying value estimator.
                Defaults to ``"terminated"``.
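            priority_weight (NestedKey): The input tensordict key where the priority
                weights (e.g., from a prioritized replay buffer) are expected.
                Defaults to ``"priority_weight"``.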
        """

        action: NestedKey = "action"
        state_action_value: NestedKey = "state_action_value"
        priority: NestedKey = "td_error"
        reward: NestedKey = "reward"
        done: NestedKey = "done"
        terminated: NestedKey = "terminated"
        priority_weight: NestedKey = "priority_weight"

    tensor_keys: _AcceptedKeys
    default_keys = _AcceptedKeys
    default_value_estimator = ValueEstimators.TD0
    out_keys = [
        "loss_actor",
        "loss_qvalue",
        "pred_value",
        "state_action_value_actor",
        "next_state_value",
        "target_value",
    ]

    actor_network: TensorDictModule
    qvalue_network: TensorDictModule
    actor_network_params: TensorDictParams
    qvalue_network_params: TensorDictParams
    target_actor_network_params: TensorDictParams
    target_qvalue_network_params: TensorDictParams

    def __init__(
        self,
        actor_network: TensorDictModule,
        qvalue_network: TensorDictModule | list[TensorDictModule],
        *,
        action_spec: TensorSpec | None = None,
        bounds: tuple[float, float] | None = None,
        num_qvalue_nets: int = 2,
        policy_noise: float = 0.2,
        noise_clip: float = 0.5,
        loss_function: str = "smooth_l1",
        delay_actor: bool = True,
        delay_qvalue: bool = True,
        gamma: float | None = None,
        priority_key: str | None = None,
        separate_losses: bool = False,
        reduction: str | None = None,
        deactivate_vmap: bool = False,
        use_prioritized_weights: str | bool = "auto",
    ) -> None:
        if reduction is None:
            reduction = "mean"
        super().__init__()
        self.use_prioritized_weights = use_prioritized_weights
        self._in_keys = None
        self._set_deprecated_ctor_keys(priority=priority_key)

        self.delay_actor = delay_actor
        self.delay_qvalue = delay_qvalue
        self.deactivate_vmap = deactivate_vmap

        self.convert_to_functional(
            actor_network,
            "actor_network",
            create_target_params=self.delay_actor,
        )
        if separate_losses:
            # make sure there are no duplicates in the params: the critic's
            # params must be references to the actor's params if they are shared
            policy_params = list(actor_network.parameters())
        else:
            policy_params = None
        self.convert_to_functional(
            qvalue_network,
            "qvalue_network",
            num_qvalue_nets,
            create_target_params=self.delay_qvalue,
            compare_against=policy_params,
        )

        for p in self.parameters():
            device = p.device
            break
        else:
            device = None
        self.num_qvalue_nets = num_qvalue_nets
        self.loss_function = loss_function
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        if not ((action_spec is not None) ^ (bounds is not None)):
            raise ValueError(
                "Exactly one of 'bounds' and 'action_spec' must be provided. "
                f"Got bounds={bounds} and action_spec={action_spec}."
            )
        elif action_spec is not None:
            if isinstance(action_spec, Composite):
                if (
                    isinstance(self.tensor_keys.action, tuple)
                    and len(self.tensor_keys.action) > 1
                ):
                    action_container_shape = action_spec[
                        self.tensor_keys.action[:-1]
                    ].shape
                else:
                    action_container_shape = action_spec.shape
                action_spec = action_spec[self.tensor_keys.action][
                    (0,) * len(action_container_shape)
                ]
            if not isinstance(action_spec, Bounded):
                raise ValueError(
                    f"action_spec is not of type Bounded but {type(action_spec)}."
                )
            low = action_spec.space.low
            high = action_spec.space.high
        else:
            low, high = bounds
        if not isinstance(low, torch.Tensor):
            low = torch.tensor(low)
        if not isinstance(high, torch.Tensor):
            high = torch.tensor(high, device=low.device, dtype=low.dtype)
        if (low > high).any():
            raise ValueError("Got a low bound higher than a high bound.")
        if device is not None:
            low = low.to(device)
            high = high.to(device)
        self.register_buffer("max_action", high)
        self.register_buffer("min_action", low)
        if gamma is not None:
            raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
        self._make_vmap()
        self.reduction = reduction

    def _make_vmap(self):
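        """Builds the (possibly pseudo-)vmapped calls over the Q-value and actor network parameters."""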
        self._vmap_qvalue_network00 = _vmap_func(
            self.qvalue_network,
            randomness=self.vmap_randomness,
            pseudo_vmap=self.deactivate_vmap,
        )
        self._vmap_actor_network00 = _vmap_func(
            self.actor_network,
            randomness=self.vmap_randomness,
            pseudo_vmap=self.deactivate_vmap,
        )

    def _forward_value_estimator_keys(self, **kwargs) -> None:
        if self._value_estimator is not None:
            self._value_estimator.set_keys(
                value=self._tensor_keys.state_action_value,
                reward=self.tensor_keys.reward,
                done=self.tensor_keys.done,
                terminated=self.tensor_keys.terminated,
            )
        self._set_in_keys()

    def _set_in_keys(self):
        keys = [
            self.tensor_keys.action,
            ("next", self.tensor_keys.reward),
            ("next", self.tensor_keys.done),
            ("next", self.tensor_keys.terminated),
            *self.actor_network.in_keys,
            *[("next", key) for key in self.actor_network.in_keys],
            *self.qvalue_network.in_keys,
        ]
        self._in_keys = list(set(keys))

    @property
    def in_keys(self):
        if self._in_keys is None:
            self._set_in_keys()
        return self._in_keys

    @in_keys.setter
    def in_keys(self, values):
        self._in_keys = values

    @property
    @_cache_values
    def _cached_detach_qvalue_network_params(self):
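        """Detached Q-value network parameters, cached so the actor loss does not backpropagate through the critic."""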
        return self.qvalue_network_params.detach()

    @property
    @_cache_values
    def _cached_stack_actor_params(self):
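        """Current and target actor parameters stacked along dim 0, cached for joint (vmapped) evaluation."""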
        return torch.stack(
            [self.actor_network_params, self.target_actor_network_params], 0
        )

    def actor_loss(self, tensordict: TensorDictBase) -> tuple[torch.Tensor, dict]:
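        """Computes the actor loss.

        The actor loss is the negated state-action value of the first Q-value
        network evaluated, with detached critic parameters, at the actions
        produced by the current policy.

        Args:
            tensordict (TensorDictBase): the input data for the loss.

        Returns:
            a differentiable tensor with the actor loss along with a metadata
            dictionary containing the detached ``"state_action_value_actor"``.
        """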
        weights = self._maybe_get_priority_weight(tensordict)
        tensordict_actor_grad = tensordict.select(
            *self.actor_network.in_keys, strict=False
        )
        with self.actor_network_params.to_module(self.actor_network):
            tensordict_actor_grad = self.actor_network(tensordict_actor_grad)
        actor_loss_td = tensordict_actor_grad.select(
            *self.qvalue_network.in_keys, strict=False
        ).expand(
            self.num_qvalue_nets, *tensordict_actor_grad.batch_size
        )  # for actor loss
        state_action_value_actor = (
            self._vmap_qvalue_network00(
                actor_loss_td,
                self._cached_detach_qvalue_network_params,
            )
            .get(self.tensor_keys.state_action_value)
            .squeeze(-1)
        )
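        # per TD3, the actor is trained against the first Q-value network only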
        loss_actor = -(state_action_value_actor[0])
        metadata = {
            "state_action_value_actor": state_action_value_actor.detach(),
        }
        loss_actor = _reduce(loss_actor, reduction=self.reduction, weights=weights)
        self._clear_weakrefs(
            tensordict,
            "actor_network_params",
            "qvalue_network_params",
            "target_actor_network_params",
            "target_qvalue_network_params",
        )
        return loss_actor, metadata

    def value_loss(self, tensordict: TensorDictBase) -> tuple[torch.Tensor, dict]:
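        """Computes the Q-value loss.

        The target action is the target policy's action perturbed with clipped
        noise (target policy smoothing); the TD target uses the minimum of the
        target Q-value networks, and each Q-value network is regressed towards
        it with ``loss_function``.

        Args:
            tensordict (TensorDictBase): the input data for the loss.

        Returns:
            a differentiable tensor with the Q-value loss along with a metadata
            dictionary containing the ``"td_error"`` and the detached
            ``"next_state_value"``, ``"pred_value"`` and ``"target_value"``.
        """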
        weights = self._maybe_get_priority_weight(tensordict)
        tensordict = tensordict.clone(False)

        act = tensordict.get(self.tensor_keys.action)

        # compute the noise early, for reproducibility of the RNG call order
        noise = (torch.randn_like(act) * self.policy_noise).clamp(
            -self.noise_clip, self.noise_clip
        )

        with torch.no_grad():
            next_td_actor = step_mdp(tensordict).select(
                *self.actor_network.in_keys, strict=False
            )  # next_observation ->
            with self.target_actor_network_params.to_module(self.actor_network):
                next_td_actor = self.actor_network(next_td_actor)
            next_action = (next_td_actor.get(self.tensor_keys.action) + noise).clamp(
                self.min_action, self.max_action
            )
            next_td_actor.set(
                self.tensor_keys.action,
                next_action,
            )
            next_val_td = next_td_actor.select(
                *self.qvalue_network.in_keys, strict=False
            ).expand(
                self.num_qvalue_nets, *next_td_actor.batch_size
            )  # for next value estimation
            next_target_q1q2 = (
                self._vmap_qvalue_network00(
                    next_val_td,
                    self.target_qvalue_network_params,
                )
                .get(self.tensor_keys.state_action_value)
                .squeeze(-1)
            )
        # min over the next target Q-values
        next_target_qvalue = next_target_q1q2.min(0)[0]

        # set next target qvalues
        tensordict.set(
            ("next", self.tensor_keys.state_action_value),
            next_target_qvalue.unsqueeze(-1),
        )

        qval_td = tensordict.select(*self.qvalue_network.in_keys, strict=False).expand(
            self.num_qvalue_nets,
            *tensordict.batch_size,
        )
        # predicted current Q-values
        current_qvalue = (
            self._vmap_qvalue_network00(
                qval_td,
                self.qvalue_network_params,
            )
            .get(self.tensor_keys.state_action_value)
            .squeeze(-1)
        )

        # compute target values for the qvalue loss (reward + gamma * next_target_qvalue * (1 - done))
        target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)

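        # squared TD error; its max over Q-networks is written under the priority key in forward()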
        td_error = (current_qvalue - target_value).pow(2)
        loss_qval = distance_loss(
            current_qvalue,
            target_value.expand_as(current_qvalue),
            loss_function=self.loss_function,
        ).sum(0)
        metadata = {
            "td_error": td_error,
            "next_state_value": next_target_qvalue.detach(),
            "pred_value": current_qvalue.detach(),
            "target_value": target_value.detach(),
        }
        loss_qval = _reduce(loss_qval, reduction=self.reduction, weights=weights)
        self._clear_weakrefs(
            tensordict,
            "actor_network_params",
            "qvalue_network_params",
            "target_actor_network_params",
            "target_qvalue_network_params",
        )
        return loss_qval, metadata

    @dispatch
    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
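        """Computes the TD3 losses for the input tensordict.

        Returns a tensordict with the differentiable ``"loss_actor"`` and
        ``"loss_qvalue"`` entries along with detached diagnostics. As a side
        effect, the per-sample TD error is written to the input tensordict
        under the priority key.
        """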
        tensordict_save = tensordict
        loss_actor, metadata_actor = self.actor_loss(tensordict)
        loss_qval, metadata_value = self.value_loss(tensordict_save)
        tensordict_save.set(
            self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0]
        )
        if loss_qval.shape != loss_actor.shape:
            raise RuntimeError(
                f"QVal and actor losses have different shapes: {loss_qval.shape} and {loss_actor.shape}"
            )
        td_out = TensorDict(
            loss_actor=loss_actor,
            loss_qvalue=loss_qval,
            **metadata_actor,
            **metadata_value,
        )
        self._clear_weakrefs(
            tensordict,
            "actor_network_params",
            "qvalue_network_params",
            "target_actor_network_params",
            "target_qvalue_network_params",
        )
        return td_out

    def make_value_estimator(self, value_type: ValueEstimators | None = None, **hyperparams):
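        """Builds and configures the value estimator (``TD0`` by default).

        Args:
            value_type (ValueEstimators, optional): the value estimator to use
                among ``TD0``, ``TD1`` and ``TDLambda``. Defaults to
                ``ValueEstimators.TD0``.
            **hyperparams: hyperparameters forwarded to the estimator (e.g. ``gamma``).
        """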
        if value_type is None:
            value_type = self.default_value_estimator
        self.value_type = value_type
        hp = dict(default_value_kwargs(value_type))
        if hasattr(self, "gamma"):
            hp["gamma"] = self.gamma
        hp.update(hyperparams)
        # we do not need a value network because the next state value is already passed
        if value_type == ValueEstimators.TD1:
            self._value_estimator = TD1Estimator(value_network=None, **hp)
        elif value_type == ValueEstimators.TD0:
            self._value_estimator = TD0Estimator(value_network=None, **hp)
        elif value_type == ValueEstimators.GAE:
            raise NotImplementedError(
                f"Value type {value_type} is not implemented for loss {type(self)}."
            )
        elif value_type == ValueEstimators.TDLambda:
            self._value_estimator = TDLambdaEstimator(value_network=None, **hp)
        else:
            raise NotImplementedError(f"Unknown value type {value_type}")

        tensor_keys = {
            "value": self.tensor_keys.state_action_value,
            "reward": self.tensor_keys.reward,
            "done": self.tensor_keys.done,
            "terminated": self.tensor_keys.terminated,
        }
        self._value_estimator.set_keys(**tensor_keys)
